{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f427d5e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.190861Z",
     "start_time": "2021-05-16T12:21:26.464606Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: Pandarallel will run on 20 workers.\n",
      "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from transformers import AutoModel\n",
    "import torch\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "from pandarallel import pandarallel\n",
    "import warnings\n",
    "\n",
    "# Notebook-wide display and runtime configuration.\n",
    "# Lift pandas' column/row display limits so rendered frames are not truncated.\n",
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.max_rows', None)\n",
    "# Initialise pandarallel (it logs its worker count as this cell's output).\n",
    "pandarallel.initialize()\n",
    "# Suppress every Python warning for the rest of the session.\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d01bc68",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.197000Z",
     "start_time": "2021-05-16T12:21:27.193297Z"
    }
   },
   "outputs": [],
   "source": [
    "# Make sure the cache (data/tmp) and output (data/embedding) directories\n",
    "# exist; exist_ok=True keeps the cell safe to re-run.\n",
    "os.makedirs('data/tmp', exist_ok=True)\n",
    "os.makedirs('data/embedding', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1c5ab012",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.324528Z",
     "start_time": "2021-05-16T12:21:27.199750Z"
    }
   },
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4ad125f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.434638Z",
     "start_time": "2021-05-16T12:21:27.326665Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "class Encoder():\n",
    "    def __init__(self,\n",
    "                 sentences,\n",
    "                 embeddings,\n",
    "                 key_name,\n",
    "                 prefix,\n",
    "                 keys_sentences_map=None):\n",
    "        self.sentences = sentences\n",
    "        self.embeddings = embeddings\n",
    "        self.key_name = key_name\n",
    "        self.prefix = prefix\n",
    "\n",
    "        if keys_sentences_map is not None:\n",
    "            self.keys_sentences_map = keys_sentences_map\n",
    "        else:\n",
    "            self.keys_sentences_map = dict(zip(sentences, sentences))\n",
    "\n",
    "        sentences_embeddings_map = dict(zip(sentences, embeddings))\n",
    "        self.keys_embeddings_map = {}\n",
    "        for key, sentence in self.keys_sentences_map.items():\n",
    "            self.keys_embeddings_map[key] = sentences_embeddings_map[sentence]\n",
    "\n",
    "    def get_embeddings(self, normalize=False):\n",
    "        if normalize and self.keys_normalize_embeddings_map is not None:\n",
    "            keys_embeddings_map = self.keys_normalize_embeddings_map\n",
    "        else:\n",
    "            keys_embeddings_map = self.keys_embeddings_map\n",
    "\n",
    "        emb_size = len(list(keys_embeddings_map.values())[0])\n",
    "\n",
    "        data_list = []\n",
    "        for key, embedding in keys_embeddings_map.items():\n",
    "            data_list.append([key] + list(embedding))\n",
    "\n",
    "        df_emb = pd.DataFrame(data_list)\n",
    "        df_emb.columns = [self.key_name] + [\n",
    "            '{}_emb_{}'.format(self.prefix, i) for i in range(emb_size)\n",
    "        ]\n",
    "\n",
    "        return df_emb\n",
    "\n",
    "    def get_embedding(self, key, normalize=False):\n",
    "        try:\n",
    "            if normalize and self.keys_normalize_embeddings_map is not None:\n",
    "                return self.keys_normalize_embeddings_map[key]\n",
    "            else:\n",
    "                return self.keys_embeddings_map[key]\n",
    "        except Exception:\n",
    "            return None\n",
    "\n",
    "    def transform_and_normalize(self, kernel, bias, n_components=None):\n",
    "        \"\"\"应用变换，然后标准化\n",
    "        \"\"\"\n",
    "        if n_components is not None:\n",
    "            kernel = kernel[:, :n_components]\n",
    "\n",
    "        if not (kernel is None or bias is None):\n",
    "            vecs = (self.embeddings + bias).dot(kernel)\n",
    "        else:\n",
    "            vecs = vecs\n",
    "\n",
    "        vecs = vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5\n",
    "\n",
    "        sentences_embeddings_map = dict(zip(self.sentences, vecs))\n",
    "        self.keys_normalize_embeddings_map = {}\n",
    "        for key, sentence in self.keys_sentences_map.items():\n",
    "            self.keys_normalize_embeddings_map[key] = sentences_embeddings_map[\n",
    "                sentence]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "275c297e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.569934Z",
     "start_time": "2021-05-16T12:21:27.437197Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "def build_model(path):\n",
    "    \"\"\"Load the tokenizer and model stored at `path`.\n",
    "\n",
    "    The model is moved to DEVICE. Returns a (tokenizer, model) tuple.\n",
    "    \"\"\"\n",
    "    loaded_tokenizer = AutoTokenizer.from_pretrained(path)\n",
    "    loaded_model = AutoModel.from_pretrained(path).to(DEVICE)\n",
    "    return loaded_tokenizer, loaded_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4c3ca7b8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.677596Z",
     "start_time": "2021-05-16T12:21:27.572566Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "def sent_to_vec(sent, tokenizer, model, pooling, max_length):\n",
    "    with torch.no_grad():\n",
    "        inputs = tokenizer(sent,\n",
    "                           return_tensors=\"pt\",\n",
    "                           padding=True,\n",
    "                           truncation=True,\n",
    "                           max_length=max_length)\n",
    "        inputs['input_ids'] = inputs['input_ids'].to(DEVICE)\n",
    "        inputs['token_type_ids'] = inputs['token_type_ids'].to(DEVICE)\n",
    "        inputs['attention_mask'] = inputs['attention_mask'].to(DEVICE)\n",
    "\n",
    "        hidden_states = model(**inputs,\n",
    "                              return_dict=True,\n",
    "                              output_hidden_states=True).hidden_states\n",
    "\n",
    "        if pooling == 'first_last_avg':\n",
    "            output_hidden_state = (hidden_states[-1] +\n",
    "                                   hidden_states[1]).mean(dim=1)\n",
    "        elif pooling == 'last_avg':\n",
    "            output_hidden_state = (hidden_states[-1]).mean(dim=1)\n",
    "        elif pooling == 'last2avg':\n",
    "            output_hidden_state = (hidden_states[-1] +\n",
    "                                   hidden_states[-2]).mean(dim=1)\n",
    "        elif pooling == 'cls':\n",
    "            output_hidden_state = (hidden_states[-1])[:, 0, :]\n",
    "        else:\n",
    "            raise Exception(\"unknown pooling {}\".format(POOLING))\n",
    "\n",
    "        vec = output_hidden_state.cpu().numpy()[0]\n",
    "    return vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f1c8de93",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.788011Z",
     "start_time": "2021-05-16T12:21:27.680147Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "def sents_to_vecs(sents, tokenizer, model, pooling, max_length, verbose=True):\n",
    "    \"\"\"Encode every sentence with sent_to_vec and stack the results.\n",
    "\n",
    "    With verbose=True the iteration is wrapped in a tqdm progress bar.\n",
    "    Returns a numpy array built from the per-sentence vectors.\n",
    "    \"\"\"\n",
    "    iterable = tqdm(sents) if verbose else sents\n",
    "    encoded = [\n",
    "        sent_to_vec(text, tokenizer, model, pooling, max_length)\n",
    "        for text in iterable\n",
    "    ]\n",
    "    # Sanity check: one vector per input sentence.\n",
    "    assert len(iterable) == len(encoded)\n",
    "    return np.array(encoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ec13e89c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:27.915916Z",
     "start_time": "2021-05-16T12:21:27.791095Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "def compute_kernel_bias(vecs):\n",
    "    \"\"\"计算kernel和bias\n",
    "    最后的变换：y = (x + bias).dot(kernel)\n",
    "    \"\"\"\n",
    "    vecs = np.concatenate(vecs, axis=0)\n",
    "    mu = vecs.mean(axis=0, keepdims=True)\n",
    "    cov = np.cov(vecs.T)\n",
    "    u, s, vh = np.linalg.svd(cov)\n",
    "    W = np.dot(u, np.diag(s**0.5))\n",
    "    W = np.linalg.inv(W.T)\n",
    "    return W, -mu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a05e83d5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:31.304643Z",
     "start_time": "2021-05-16T12:21:27.920504Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load the tokenizer and model from the local pretrained checkpoint directory\n",
    "path = 'data/pretrain_models/ernie'\n",
    "tokenizer, model = build_model(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3094989",
   "metadata": {},
   "source": [
    "# 生成embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e43ed25d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:31.310720Z",
     "start_time": "2021-05-16T12:21:31.307640Z"
    }
   },
   "outputs": [],
   "source": [
    "# Settings shared by the embedding cells below.\n",
    "vecs_list = []  # embedding matrices, later passed to compute_kernel_bias\n",
    "pooling = 'cls'  # pooling mode handed to sent_to_vec\n",
    "max_length = 512  # truncation length handed to the tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d76da610",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:31.424797Z",
     "start_time": "2021-05-16T12:21:31.312609Z"
    },
    "code_folding": [],
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Job titles from the recruitment postings\n",
    "def get_job_title_encoder():\n",
    "    \"\"\"Return (vecs, encoder) for the distinct JOB_TITLE values.\n",
    "\n",
    "    The cached pair under data/tmp is used when it can be read; otherwise\n",
    "    the titles are embedded and the cache files are written.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        vecs = np.load('data/tmp/job_title_vecs.npy')\n",
    "        # NOTE: pickle.load can execute arbitrary code; only load cache\n",
    "        # files that this notebook wrote itself.\n",
    "        with open('data/tmp/job_title_encoder.txt', 'rb') as f:\n",
    "            job_title_encoder = pickle.load(f)\n",
    "\n",
    "    except Exception:\n",
    "        # Best effort: any failure to read the cache counts as a cache miss.\n",
    "        df_recruit = pd.read_csv('raw_data/trainset/recruit.csv')\n",
    "        # Missing titles become '' (as the MAJOR encoders do) so the tokenizer\n",
    "        # only receives strings; drop_duplicates keeps first-appearance order,\n",
    "        # which is reproducible across runs (set() ordering is not).\n",
    "        titles = df_recruit['JOB_TITLE'].fillna('')\n",
    "        sentences = titles.drop_duplicates().tolist()\n",
    "        vecs = sents_to_vecs(sentences, tokenizer, model, pooling, max_length)\n",
    "        job_title_encoder = Encoder(sentences, vecs, 'JOB_TITLE',\n",
    "                                    'JOB_TITLE_ernie')\n",
    "\n",
    "        np.save('data/tmp/job_title_vecs.npy', vecs)\n",
    "        with open('data/tmp/job_title_encoder.txt', 'wb') as f:\n",
    "            pickle.dump(job_title_encoder, f)\n",
    "\n",
    "    return vecs, job_title_encoder\n",
    "\n",
    "\n",
    "vecs, job_title_encoder = get_job_title_encoder()\n",
    "vecs_list.append(vecs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1f85b11a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:31.509510Z",
     "start_time": "2021-05-16T12:21:31.427430Z"
    },
    "code_folding": [
     9
    ]
   },
   "outputs": [],
   "source": [
    "def major_clean(x):\n",
    "    if type(x) == float:\n",
    "        return x\n",
    "\n",
    "    x = x.replace('【', '').replace('】', '')\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f74a4a25",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:31.613544Z",
     "start_time": "2021-05-16T12:21:31.512154Z"
    }
   },
   "outputs": [],
   "source": [
    "# Major required of applicants, from the recruitment postings\n",
    "def get_recruit_major_encoder():\n",
    "    \"\"\"Return (vecs, encoder) for the distinct cleaned recruit MAJOR values.\n",
    "\n",
    "    The cached pair under data/tmp is used when it can be read; otherwise\n",
    "    the majors are embedded and the cache files are written.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        vecs = np.load('data/tmp/recruit_major_vecs.npy')\n",
    "        # NOTE: pickle.load can execute arbitrary code; only load cache\n",
    "        # files that this notebook wrote itself.\n",
    "        with open('data/tmp/recruit_major_encoder.txt', 'rb') as f:\n",
    "            recruit_major_encoder = pickle.load(f)\n",
    "\n",
    "    except Exception:\n",
    "        # Best effort: any failure to read the cache counts as a cache miss.\n",
    "        df_recruit = pd.read_csv('raw_data/trainset/recruit.csv')\n",
    "        # Use the value returned by fillna: the chained inplace form does not\n",
    "        # update the frame under pandas Copy-on-Write, leaving NaN in place.\n",
    "        majors = df_recruit['MAJOR'].fillna('').apply(major_clean)\n",
    "        # First-appearance order is reproducible across runs (set() is not).\n",
    "        sentences = majors.drop_duplicates().tolist()\n",
    "        vecs = sents_to_vecs(sentences, tokenizer, model, pooling, max_length)\n",
    "        recruit_major_encoder = Encoder(sentences, vecs, 'MAJOR',\n",
    "                                        'recruit_MAJOR_ernie')\n",
    "\n",
    "        np.save('data/tmp/recruit_major_vecs.npy', vecs)\n",
    "        with open('data/tmp/recruit_major_encoder.txt', 'wb') as f:\n",
    "            pickle.dump(recruit_major_encoder, f)\n",
    "\n",
    "    return vecs, recruit_major_encoder\n",
    "\n",
    "\n",
    "vecs, recruit_major_encoder = get_recruit_major_encoder()\n",
    "vecs_list.append(vecs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5409d4f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:31.737153Z",
     "start_time": "2021-05-16T12:21:31.616305Z"
    }
   },
   "outputs": [],
   "source": [
    "# Applicant's major, from the job-seeker basic information\n",
    "def get_person_major_encoder():\n",
    "    \"\"\"Return (vecs, encoder) for the distinct cleaned person MAJOR values.\n",
    "\n",
    "    The cached pair under data/tmp is used when it can be read; otherwise\n",
    "    the majors are embedded and the cache files are written.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        vecs = np.load('data/tmp/person_major_vecs.npy')\n",
    "        # NOTE: pickle.load can execute arbitrary code; only load cache\n",
    "        # files that this notebook wrote itself.\n",
    "        with open('data/tmp/person_major_encoder.txt', 'rb') as f:\n",
    "            person_major_encoder = pickle.load(f)\n",
    "\n",
    "    except Exception:\n",
    "        # Best effort: any failure to read the cache counts as a cache miss.\n",
    "        df_person = pd.read_csv('raw_data/trainset/person.csv')\n",
    "        # Use the value returned by fillna: the chained inplace form does not\n",
    "        # update the frame under pandas Copy-on-Write, leaving NaN in place.\n",
    "        majors = df_person['MAJOR'].fillna('').apply(major_clean)\n",
    "        # First-appearance order is reproducible across runs (set() is not).\n",
    "        sentences = majors.drop_duplicates().tolist()\n",
    "        vecs = sents_to_vecs(sentences, tokenizer, model, pooling, max_length)\n",
    "        person_major_encoder = Encoder(sentences, vecs, 'MAJOR',\n",
    "                                       'person_MAJOR_ernie')\n",
    "\n",
    "        np.save('data/tmp/person_major_vecs.npy', vecs)\n",
    "        with open('data/tmp/person_major_encoder.txt', 'wb') as f:\n",
    "            pickle.dump(person_major_encoder, f)\n",
    "\n",
    "    return vecs, person_major_encoder\n",
    "\n",
    "\n",
    "vecs, person_major_encoder = get_person_major_encoder()\n",
    "vecs_list.append(vecs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "554389a2",
   "metadata": {},
   "source": [
    "# BERT-whitening\n",
    "https://kexue.fm/archives/8321"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9c284d37",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.042115Z",
     "start_time": "2021-05-16T12:21:31.739738Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fit a single whitening transform on every matrix collected in vecs_list.\n",
    "kernel, bias = compute_kernel_bias(vecs_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f42fa930",
   "metadata": {},
   "source": [
    "# 保存embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "03c5ef69",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.061732Z",
     "start_time": "2021-05-16T12:21:32.043472Z"
    }
   },
   "outputs": [],
   "source": [
    "# Whiten with the shared kernel/bias, keep the first 30 components, then\n",
    "# export the normalised table.\n",
    "job_title_encoder.transform_and_normalize(kernel, bias, n_components=30)\n",
    "job_title_embeddings = job_title_encoder.get_embeddings(normalize=True)\n",
    "job_title_embeddings.to_pickle('data/embedding/job_title.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6f078a3a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.199012Z",
     "start_time": "2021-05-16T12:21:32.062949Z"
    }
   },
   "outputs": [],
   "source": [
    "# Whiten with the shared kernel/bias, keep the first 30 components, then\n",
    "# export the normalised table.\n",
    "recruit_major_encoder.transform_and_normalize(kernel, bias, n_components=30)\n",
    "recruit_major_embeddings = recruit_major_encoder.get_embeddings(normalize=True)\n",
    "recruit_major_embeddings.to_pickle('data/embedding/recruit_major.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "bdaa1105",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.328048Z",
     "start_time": "2021-05-16T12:21:32.201866Z"
    }
   },
   "outputs": [],
   "source": [
    "# Whiten with the shared kernel/bias, keep the first 30 components, then\n",
    "# export the normalised table.\n",
    "person_major_encoder.transform_and_normalize(kernel, bias, n_components=30)\n",
    "person_major_embeddings = person_major_encoder.get_embeddings(normalize=True)\n",
    "person_major_embeddings.to_pickle('data/embedding/person_major.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "da370dcf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.467025Z",
     "start_time": "2021-05-16T12:21:32.330019Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MAJOR</th>\n",
       "      <th>person_MAJOR_ernie_emb_0</th>\n",
       "      <th>person_MAJOR_ernie_emb_1</th>\n",
       "      <th>person_MAJOR_ernie_emb_2</th>\n",
       "      <th>person_MAJOR_ernie_emb_3</th>\n",
       "      <th>person_MAJOR_ernie_emb_4</th>\n",
       "      <th>person_MAJOR_ernie_emb_5</th>\n",
       "      <th>person_MAJOR_ernie_emb_6</th>\n",
       "      <th>person_MAJOR_ernie_emb_7</th>\n",
       "      <th>person_MAJOR_ernie_emb_8</th>\n",
       "      <th>person_MAJOR_ernie_emb_9</th>\n",
       "      <th>person_MAJOR_ernie_emb_10</th>\n",
       "      <th>person_MAJOR_ernie_emb_11</th>\n",
       "      <th>person_MAJOR_ernie_emb_12</th>\n",
       "      <th>person_MAJOR_ernie_emb_13</th>\n",
       "      <th>person_MAJOR_ernie_emb_14</th>\n",
       "      <th>person_MAJOR_ernie_emb_15</th>\n",
       "      <th>person_MAJOR_ernie_emb_16</th>\n",
       "      <th>person_MAJOR_ernie_emb_17</th>\n",
       "      <th>person_MAJOR_ernie_emb_18</th>\n",
       "      <th>person_MAJOR_ernie_emb_19</th>\n",
       "      <th>person_MAJOR_ernie_emb_20</th>\n",
       "      <th>person_MAJOR_ernie_emb_21</th>\n",
       "      <th>person_MAJOR_ernie_emb_22</th>\n",
       "      <th>person_MAJOR_ernie_emb_23</th>\n",
       "      <th>person_MAJOR_ernie_emb_24</th>\n",
       "      <th>person_MAJOR_ernie_emb_25</th>\n",
       "      <th>person_MAJOR_ernie_emb_26</th>\n",
       "      <th>person_MAJOR_ernie_emb_27</th>\n",
       "      <th>person_MAJOR_ernie_emb_28</th>\n",
       "      <th>person_MAJOR_ernie_emb_29</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td></td>\n",
       "      <td>0.033922</td>\n",
       "      <td>0.016015</td>\n",
       "      <td>-0.079779</td>\n",
       "      <td>0.032491</td>\n",
       "      <td>0.194276</td>\n",
       "      <td>0.181433</td>\n",
       "      <td>-0.021364</td>\n",
       "      <td>0.019380</td>\n",
       "      <td>-0.204926</td>\n",
       "      <td>-0.107664</td>\n",
       "      <td>-0.020242</td>\n",
       "      <td>0.035183</td>\n",
       "      <td>0.269443</td>\n",
       "      <td>-0.102305</td>\n",
       "      <td>-0.013670</td>\n",
       "      <td>-0.260084</td>\n",
       "      <td>0.467201</td>\n",
       "      <td>0.224754</td>\n",
       "      <td>0.101588</td>\n",
       "      <td>-0.268790</td>\n",
       "      <td>-0.169664</td>\n",
       "      <td>-0.233692</td>\n",
       "      <td>0.126631</td>\n",
       "      <td>-0.143620</td>\n",
       "      <td>-0.339964</td>\n",
       "      <td>-0.121349</td>\n",
       "      <td>0.125929</td>\n",
       "      <td>-0.289600</td>\n",
       "      <td>-0.086212</td>\n",
       "      <td>0.068332</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>无机化学</td>\n",
       "      <td>0.161805</td>\n",
       "      <td>0.143702</td>\n",
       "      <td>-0.093986</td>\n",
       "      <td>-0.055504</td>\n",
       "      <td>-0.106657</td>\n",
       "      <td>-0.039275</td>\n",
       "      <td>0.158578</td>\n",
       "      <td>0.218503</td>\n",
       "      <td>-0.066466</td>\n",
       "      <td>-0.029445</td>\n",
       "      <td>0.023388</td>\n",
       "      <td>0.110090</td>\n",
       "      <td>-0.407006</td>\n",
       "      <td>0.180535</td>\n",
       "      <td>-0.170367</td>\n",
       "      <td>-0.233926</td>\n",
       "      <td>0.015842</td>\n",
       "      <td>-0.230791</td>\n",
       "      <td>-0.089471</td>\n",
       "      <td>0.214549</td>\n",
       "      <td>-0.161172</td>\n",
       "      <td>-0.064123</td>\n",
       "      <td>0.236262</td>\n",
       "      <td>-0.294073</td>\n",
       "      <td>0.051690</td>\n",
       "      <td>0.166714</td>\n",
       "      <td>-0.372305</td>\n",
       "      <td>0.003644</td>\n",
       "      <td>0.314602</td>\n",
       "      <td>0.086595</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>核科学与技术</td>\n",
       "      <td>0.163946</td>\n",
       "      <td>0.148324</td>\n",
       "      <td>-0.277979</td>\n",
       "      <td>-0.226509</td>\n",
       "      <td>0.118653</td>\n",
       "      <td>0.126283</td>\n",
       "      <td>-0.246546</td>\n",
       "      <td>0.040637</td>\n",
       "      <td>-0.125602</td>\n",
       "      <td>0.254303</td>\n",
       "      <td>-0.018426</td>\n",
       "      <td>-0.144588</td>\n",
       "      <td>-0.292179</td>\n",
       "      <td>-0.126347</td>\n",
       "      <td>0.343572</td>\n",
       "      <td>-0.240943</td>\n",
       "      <td>0.122548</td>\n",
       "      <td>-0.225315</td>\n",
       "      <td>0.040632</td>\n",
       "      <td>-0.180738</td>\n",
       "      <td>-0.045741</td>\n",
       "      <td>-0.189579</td>\n",
       "      <td>-0.364384</td>\n",
       "      <td>0.041531</td>\n",
       "      <td>-0.088535</td>\n",
       "      <td>0.146077</td>\n",
       "      <td>0.130512</td>\n",
       "      <td>-0.178523</td>\n",
       "      <td>-0.010251</td>\n",
       "      <td>0.011866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>教育技术学</td>\n",
       "      <td>0.168010</td>\n",
       "      <td>0.316615</td>\n",
       "      <td>-0.187714</td>\n",
       "      <td>-0.108187</td>\n",
       "      <td>0.020457</td>\n",
       "      <td>-0.318753</td>\n",
       "      <td>-0.148261</td>\n",
       "      <td>-0.392003</td>\n",
       "      <td>-0.113287</td>\n",
       "      <td>-0.277270</td>\n",
       "      <td>0.057580</td>\n",
       "      <td>0.172383</td>\n",
       "      <td>-0.116611</td>\n",
       "      <td>-0.278950</td>\n",
       "      <td>-0.207907</td>\n",
       "      <td>0.029294</td>\n",
       "      <td>-0.228712</td>\n",
       "      <td>0.104740</td>\n",
       "      <td>0.031003</td>\n",
       "      <td>-0.156405</td>\n",
       "      <td>-0.080898</td>\n",
       "      <td>0.193279</td>\n",
       "      <td>0.060483</td>\n",
       "      <td>0.290677</td>\n",
       "      <td>0.006571</td>\n",
       "      <td>-0.157274</td>\n",
       "      <td>-0.077795</td>\n",
       "      <td>-0.005406</td>\n",
       "      <td>0.055618</td>\n",
       "      <td>-0.184305</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>公共管理</td>\n",
       "      <td>0.137709</td>\n",
       "      <td>0.136864</td>\n",
       "      <td>0.216789</td>\n",
       "      <td>-0.002276</td>\n",
       "      <td>0.054183</td>\n",
       "      <td>-0.362888</td>\n",
       "      <td>0.016013</td>\n",
       "      <td>-0.018113</td>\n",
       "      <td>0.042418</td>\n",
       "      <td>0.247334</td>\n",
       "      <td>-0.146403</td>\n",
       "      <td>0.135248</td>\n",
       "      <td>0.101729</td>\n",
       "      <td>-0.044444</td>\n",
       "      <td>0.097301</td>\n",
       "      <td>0.196056</td>\n",
       "      <td>-0.305892</td>\n",
       "      <td>-0.231472</td>\n",
       "      <td>0.064632</td>\n",
       "      <td>0.152949</td>\n",
       "      <td>-0.280432</td>\n",
       "      <td>0.244647</td>\n",
       "      <td>-0.125882</td>\n",
       "      <td>0.061134</td>\n",
       "      <td>0.076664</td>\n",
       "      <td>-0.002361</td>\n",
       "      <td>0.007039</td>\n",
       "      <td>0.012406</td>\n",
       "      <td>-0.219187</td>\n",
       "      <td>0.479831</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    MAJOR  person_MAJOR_ernie_emb_0  person_MAJOR_ernie_emb_1  \\\n",
       "0                          0.033922                  0.016015   \n",
       "1    无机化学                  0.161805                  0.143702   \n",
       "2  核科学与技术                  0.163946                  0.148324   \n",
       "3   教育技术学                  0.168010                  0.316615   \n",
       "4    公共管理                  0.137709                  0.136864   \n",
       "\n",
       "   person_MAJOR_ernie_emb_2  person_MAJOR_ernie_emb_3  \\\n",
       "0                 -0.079779                  0.032491   \n",
       "1                 -0.093986                 -0.055504   \n",
       "2                 -0.277979                 -0.226509   \n",
       "3                 -0.187714                 -0.108187   \n",
       "4                  0.216789                 -0.002276   \n",
       "\n",
       "   person_MAJOR_ernie_emb_4  person_MAJOR_ernie_emb_5  \\\n",
       "0                  0.194276                  0.181433   \n",
       "1                 -0.106657                 -0.039275   \n",
       "2                  0.118653                  0.126283   \n",
       "3                  0.020457                 -0.318753   \n",
       "4                  0.054183                 -0.362888   \n",
       "\n",
       "   person_MAJOR_ernie_emb_6  person_MAJOR_ernie_emb_7  \\\n",
       "0                 -0.021364                  0.019380   \n",
       "1                  0.158578                  0.218503   \n",
       "2                 -0.246546                  0.040637   \n",
       "3                 -0.148261                 -0.392003   \n",
       "4                  0.016013                 -0.018113   \n",
       "\n",
       "   person_MAJOR_ernie_emb_8  person_MAJOR_ernie_emb_9  \\\n",
       "0                 -0.204926                 -0.107664   \n",
       "1                 -0.066466                 -0.029445   \n",
       "2                 -0.125602                  0.254303   \n",
       "3                 -0.113287                 -0.277270   \n",
       "4                  0.042418                  0.247334   \n",
       "\n",
       "   person_MAJOR_ernie_emb_10  person_MAJOR_ernie_emb_11  \\\n",
       "0                  -0.020242                   0.035183   \n",
       "1                   0.023388                   0.110090   \n",
       "2                  -0.018426                  -0.144588   \n",
       "3                   0.057580                   0.172383   \n",
       "4                  -0.146403                   0.135248   \n",
       "\n",
       "   person_MAJOR_ernie_emb_12  person_MAJOR_ernie_emb_13  \\\n",
       "0                   0.269443                  -0.102305   \n",
       "1                  -0.407006                   0.180535   \n",
       "2                  -0.292179                  -0.126347   \n",
       "3                  -0.116611                  -0.278950   \n",
       "4                   0.101729                  -0.044444   \n",
       "\n",
       "   person_MAJOR_ernie_emb_14  person_MAJOR_ernie_emb_15  \\\n",
       "0                  -0.013670                  -0.260084   \n",
       "1                  -0.170367                  -0.233926   \n",
       "2                   0.343572                  -0.240943   \n",
       "3                  -0.207907                   0.029294   \n",
       "4                   0.097301                   0.196056   \n",
       "\n",
       "   person_MAJOR_ernie_emb_16  person_MAJOR_ernie_emb_17  \\\n",
       "0                   0.467201                   0.224754   \n",
       "1                   0.015842                  -0.230791   \n",
       "2                   0.122548                  -0.225315   \n",
       "3                  -0.228712                   0.104740   \n",
       "4                  -0.305892                  -0.231472   \n",
       "\n",
       "   person_MAJOR_ernie_emb_18  person_MAJOR_ernie_emb_19  \\\n",
       "0                   0.101588                  -0.268790   \n",
       "1                  -0.089471                   0.214549   \n",
       "2                   0.040632                  -0.180738   \n",
       "3                   0.031003                  -0.156405   \n",
       "4                   0.064632                   0.152949   \n",
       "\n",
       "   person_MAJOR_ernie_emb_20  person_MAJOR_ernie_emb_21  \\\n",
       "0                  -0.169664                  -0.233692   \n",
       "1                  -0.161172                  -0.064123   \n",
       "2                  -0.045741                  -0.189579   \n",
       "3                  -0.080898                   0.193279   \n",
       "4                  -0.280432                   0.244647   \n",
       "\n",
       "   person_MAJOR_ernie_emb_22  person_MAJOR_ernie_emb_23  \\\n",
       "0                   0.126631                  -0.143620   \n",
       "1                   0.236262                  -0.294073   \n",
       "2                  -0.364384                   0.041531   \n",
       "3                   0.060483                   0.290677   \n",
       "4                  -0.125882                   0.061134   \n",
       "\n",
       "   person_MAJOR_ernie_emb_24  person_MAJOR_ernie_emb_25  \\\n",
       "0                  -0.339964                  -0.121349   \n",
       "1                   0.051690                   0.166714   \n",
       "2                  -0.088535                   0.146077   \n",
       "3                   0.006571                  -0.157274   \n",
       "4                   0.076664                  -0.002361   \n",
       "\n",
       "   person_MAJOR_ernie_emb_26  person_MAJOR_ernie_emb_27  \\\n",
       "0                   0.125929                  -0.289600   \n",
       "1                  -0.372305                   0.003644   \n",
       "2                   0.130512                  -0.178523   \n",
       "3                  -0.077795                  -0.005406   \n",
       "4                   0.007039                   0.012406   \n",
       "\n",
       "   person_MAJOR_ernie_emb_28  person_MAJOR_ernie_emb_29  \n",
       "0                  -0.086212                   0.068332  \n",
       "1                   0.314602                   0.086595  \n",
       "2                  -0.010251                   0.011866  \n",
       "3                   0.055618                  -0.184305  \n",
       "4                  -0.219187                   0.479831  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the person-side MAJOR embedding columns.\n",
    "person_major_embeddings.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9d42cf5",
   "metadata": {},
   "source": [
    "# 计算匹配度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4e685135",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.775360Z",
     "start_time": "2021-05-16T12:21:32.468152Z"
    }
   },
   "outputs": [],
   "source": [
    "# Stack the train pairs and the test pairs into a single frame.\n",
    "df_train = pd.read_csv('raw_data/trainset/recruit_folder.csv')\n",
    "df_test = pd.read_csv('raw_data/testset/recruit_folder.csv')\n",
    "df_test['LABEL'] = np.nan  # test rows get a NaN placeholder label\n",
    "# pd.concat replaces DataFrame.append, which was deprecated in pandas 1.4\n",
    "# and removed in 2.0; the merges below rebuild the index either way.\n",
    "df_feature = pd.concat([df_train, df_test], sort=False)\n",
    "\n",
    "# Attach the MAJOR text of the recruit side and of the person side.\n",
    "# NOTE(review): only the trainset copies of recruit.csv / person.csv are\n",
    "# read here -- confirm they also cover the ids that appear in the test set.\n",
    "df_recruit = pd.read_csv('raw_data/trainset/recruit.csv')\n",
    "df_feature = df_feature.merge(\n",
    "    df_recruit[['RECRUIT_ID', 'MAJOR']], how='left', on='RECRUIT_ID'\n",
    ").rename(columns={'MAJOR': 'recruit_MAJOR'})\n",
    "\n",
    "df_person = pd.read_csv('raw_data/trainset/person.csv')\n",
    "df_feature = df_feature.merge(\n",
    "    df_person[['PERSON_ID', 'MAJOR']], how='left', on='PERSON_ID'\n",
    ").rename(columns={'MAJOR': 'person_MAJOR'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "67c0ff33",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.783148Z",
     "start_time": "2021-05-16T12:21:32.776579Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RECRUIT_ID</th>\n",
       "      <th>PERSON_ID</th>\n",
       "      <th>LABEL</th>\n",
       "      <th>recruit_MAJOR</th>\n",
       "      <th>person_MAJOR</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>825081</td>\n",
       "      <td>6256839</td>\n",
       "      <td>0.0</td>\n",
       "      <td>工业自动化</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>772899</td>\n",
       "      <td>5413605</td>\n",
       "      <td>0.0</td>\n",
       "      <td>旅游管理</td>\n",
       "      <td>文秘</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>795668</td>\n",
       "      <td>5219796</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>财政学（含税收学）</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>769754</td>\n",
       "      <td>5700693</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>计算机应用技术</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>773645</td>\n",
       "      <td>6208645</td>\n",
       "      <td>0.0</td>\n",
       "      <td>汽车工程</td>\n",
       "      <td>计算机应用技术</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   RECRUIT_ID  PERSON_ID  LABEL recruit_MAJOR person_MAJOR\n",
       "0      825081    6256839    0.0         工业自动化          NaN\n",
       "1      772899    5413605    0.0          旅游管理           文秘\n",
       "2      795668    5219796    0.0           NaN    财政学（含税收学）\n",
       "3      769754    5700693    0.0           NaN      计算机应用技术\n",
       "4      773645    6208645    0.0          汽车工程      计算机应用技术"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the merged frame: ids, LABEL and the two MAJOR text columns.\n",
    "df_feature.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "75a71cac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:32.876316Z",
     "start_time": "2021-05-16T12:21:32.784215Z"
    }
   },
   "outputs": [],
   "source": [
    "def consine(vector1, vector2):\n",
    "    if type(vector1) != np.ndarray or type(vector2) != np.ndarray:\n",
    "        return -1\n",
    "    distance = np.dot(vector1, vector2) / \\\n",
    "        (np.linalg.norm(vector1)*(np.linalg.norm(vector2)))\n",
    "    return distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "16dab53e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:34.483997Z",
     "start_time": "2021-05-16T12:21:32.879917Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def _major_match_score(row):\n",
    "    \"\"\"Score one (recruit_MAJOR, person_MAJOR) row with consine().\"\"\"\n",
    "    # The second positional argument (True) is passed through unchanged;\n",
    "    # its meaning is defined by the encoder's get_embedding.\n",
    "    recruit_vec = recruit_major_encoder.get_embedding(row['recruit_MAJOR'], True)\n",
    "    person_vec = person_major_encoder.get_embedding(row['person_MAJOR'], True)\n",
    "    return consine(recruit_vec, person_vec)\n",
    "\n",
    "\n",
    "major_pairs = df_feature[['recruit_MAJOR', 'person_MAJOR']]\n",
    "df_feature['recruit_person_MAJOR_score'] = major_pairs.apply(\n",
    "    _major_match_score, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "78bcfcdf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-16T12:21:34.491790Z",
     "start_time": "2021-05-16T12:21:34.485365Z"
    }
   },
   "outputs": [],
   "source": [
    "# Persist the pair ids together with the MAJOR match score.\n",
    "score_columns = ['RECRUIT_ID', 'PERSON_ID', 'recruit_person_MAJOR_score']\n",
    "df_score = df_feature[score_columns]\n",
    "df_score.to_pickle('data/score.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f846b67",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dm] *",
   "language": "python",
   "name": "conda-env-dm-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
